package com.chatplus.application.aiprocessor.platform.image;

import cn.bugstack.openai.session.OpenAiSession;
import cn.hutool.core.util.RandomUtil;
import com.chatplus.application.common.enumeration.UserStatusEnum;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.logging.SouthernQuietLogger;
import com.chatplus.application.common.logging.SouthernQuietLoggerFactory;
import com.chatplus.application.common.util.OkHttpClientUtil;
import com.chatplus.application.domain.dto.ApiKeyDto;
import com.chatplus.application.domain.dto.extend.ImgResultDto;
import com.chatplus.application.domain.entity.account.UserEntity;
import com.chatplus.application.domain.vo.file.UpLoadFileVo;
import com.chatplus.application.enumeration.AiPlatformEnum;
import com.chatplus.application.enumeration.ImageTaskTypeEnum;
import com.chatplus.application.service.account.UserProductLogService;
import com.chatplus.application.service.account.UserService;
import com.chatplus.application.service.file.FileService;
import com.chatplus.application.service.provider.FileServiceProvider;
import com.chatplus.application.util.ConfigUtil;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.tika.Tika;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.web.multipart.MultipartFile;

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.zip.GZIPInputStream;

/**
 * Base class for image-type AI processor services.
 *
 * <p>Concrete subclasses implement {@link #process(Object)} and {@link #getChannel()} for one
 * AI platform. The optional hooks ({@link #getSessionFactory()}, {@link #getImageProgress(Long)},
 * {@link #getRunningJobCount(Long)}, {@link #notifyUpdateTask(Long)}) throw
 * {@link UnsupportedOperationException} unless a subclass overrides them.
 */
public abstract class ImgAiProcessorService {
    private static final SouthernQuietLogger LOGGER = SouthernQuietLoggerFactory.getLogger(ImgAiProcessorService.class);
    // Upper bound on the number of jobs a single user may have running at the same time.
    private static final long MAX_RUNNING_JOB_COUNT = 5;
    // Discord CDN URLs are rewritten to this mirror host before downloading.
    private static final String DISCORD_CDN_HOST = "https://cdn.discordapp.com";
    private static final String DISCORD_MIRROR_HOST = "https://discord.renrenai.online";
    private static final String GZIP_MIME_TYPE = "application/gzip";

    private FileServiceProvider fileServiceProvider;
    private UserService userService;
    private UserProductLogService userProductLogService;
    private final Tika tika;
    // API keys for this channel; refreshed from configuration on every instance() call.
    protected List<ApiKeyDto> openApiKey = new ArrayList<>();

    protected ImgAiProcessorService() {
        tika = new Tika();
    }

    @Autowired
    public void setFileServiceProvider(FileServiceProvider fileServiceProvider) {
        this.fileServiceProvider = fileServiceProvider;
    }

    @Autowired
    public void setUserService(UserService userService) {
        this.userService = userService;
    }

    @Autowired
    public void setUserProductLogService(UserProductLogService userProductLogService) {
        this.userProductLogService = userProductLogService;
    }

    /**
     * Runs the image task described by {@code prompt}.
     *
     * @param prompt platform-specific request object
     * @return the result of the image task
     */
    public abstract ImgResultDto process(Object prompt);

    /**
     * Returns the AI platform (channel) this processor serves.
     */
    public abstract AiPlatformEnum getChannel();

    /**
     * Returns the session factory for this platform.
     *
     * @throws UnsupportedOperationException always, unless overridden by a subclass
     */
    public synchronized OpenAiSession getSessionFactory() {
        throw new UnsupportedOperationException("不支持的操作");
    }

    /**
     * Returns the drawing progress of the given job.
     *
     * @param id job identifier
     * @throws UnsupportedOperationException always, unless overridden by a subclass
     */
    public int getImageProgress(Long id) {
        throw new UnsupportedOperationException("不支持的操作");
    }

    /**
     * Returns the number of jobs currently running for the given user.
     *
     * <p>{@link #verifyUserRequest(Long, ImageTaskTypeEnum)} calls this method, so subclasses
     * that rely on that check must override it.
     *
     * @param userId user identifier
     * @throws UnsupportedOperationException always, unless overridden by a subclass
     */
    public long getRunningJobCount(Long userId) {
        throw new UnsupportedOperationException("不支持的操作");
    }

    /**
     * Notifies that the given job has been updated.
     *
     * @param jobId job identifier
     * @throws UnsupportedOperationException always, unless overridden by a subclass
     */
    public void notifyUpdateTask(Long jobId) {
        throw new UnsupportedOperationException("不支持的操作");
    }

    /**
     * Reloads the API keys configured for this channel into {@link #openApiKey}.
     *
     * @throws BadRequestException if no API key is available for this channel
     */
    public void instance() {
        openApiKey = ConfigUtil.getImageApiKey(getChannel());
        if (CollectionUtils.isEmpty(openApiKey)) {
            throw new BadRequestException("抱歉😔😔😔，系统已经没有可用的 API KEY，请联系管理员！");
        }
    }

    /**
     * Validates a user's image request. The checks run in this order:
     * <ol>
     *   <li>an API key must be available for this channel (see {@link #instance()});</li>
     *   <li>the user must be below the concurrent job limit;</li>
     *   <li>the user must exist and have status {@link UserStatusEnum#OK};</li>
     *   <li>the user's remaining power must cover the power this task consumes.</li>
     * </ol>
     *
     * @param userId        user identifier
     * @param imageTaskType type of image task, used to look up the power the task consumes
     * @return {@code null} when the request is allowed, otherwise a user-facing rejection message
     * @throws BadRequestException if no API key is available for this channel
     */
    public String verifyUserRequest(Long userId, ImageTaskTypeEnum imageTaskType) {
        instance();
        if (getRunningJobCount(userId) >= MAX_RUNNING_JOB_COUNT) {
            return "您当前正在进行的任务数量已经达到上限，请稍后再试！";
        }
        // User checks.
        UserEntity userEntity = userService.getById(userId);
        if (userEntity == null) {
            return "非法用户，请联系管理员！";
        }
        if (userEntity.getStatus() != UserStatusEnum.OK) {
            return "您的账号已经被禁用，如果疑问，请联系管理员！";
        }
        // Balance check: reject an empty balance and also a balance smaller than the task's cost,
        // which is what the rejection message reports.
        int power = ConfigUtil.getImagePower(userId, getChannel(), imageTaskType);
        int imageCalls = userProductLogService.getUserChatPower(userId);
        if (imageCalls <= 0 || imageCalls < power) {
            return String.format("您当前剩余绘画算力（%d）已不足以支付当前绘画需要消耗的绘画算力（%d）！"
                    , imageCalls, power);
        }
        return null;
    }


    /**
     * Stores an image through the active file service and returns the stored URL.
     *
     * <p>The Base64 payload is tried first. If it is empty, cannot be decoded, or the upload
     * yields no URL, the image is downloaded from {@code url} and uploaded instead.
     *
     * @param url    remote image URL used as fallback; may be empty
     * @param base64 Base64-encoded image content; may be empty
     * @return the stored URL, or {@code null} if nothing could be stored
     * @throws IOException if downloading or uploading fails
     */
    public String saveToOss(String url, String base64) throws IOException {
        String resultUrl = null;
        if (StringUtils.isNotEmpty(base64)) {
            InputStream decoded = base2InputStream(base64);
            if (decoded != null) {
                resultUrl = uploadFile(decoded);
            }
        }
        if (StringUtils.isNotEmpty(resultUrl) || StringUtils.isEmpty(url)) {
            return resultUrl;
        }
        return downloadAndUpload(url);
    }

    /**
     * Downloads the image at {@code url} and uploads it through the active file service.
     *
     * @return the stored URL, or {@code null} if the download was unsuccessful or had no body
     */
    private String downloadAndUpload(String url) throws IOException {
        // Literal prefix swap; avoids interpreting the host name as a regular expression.
        String downloadUrl = url.startsWith(DISCORD_CDN_HOST)
                ? DISCORD_MIRROR_HOST + url.substring(DISCORD_CDN_HOST.length())
                : url;
        OkHttpClient client = OkHttpClientUtil.getOkHttpClient(null, true);
        Request request = new Request.Builder()
                .addHeader("Accept", "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7")
                // Setting Accept-Encoding by hand disables OkHttp's transparent decompression,
                // so only advertise gzip, the one encoding that is decoded below.
                .addHeader("Accept-Encoding", "gzip")
                .addHeader("Accept-Language", "zh-CN,zh;q=0.9,en-US;q=0.8,en;q=0.7")
                .addHeader("Sec-Ch-Ua", "\"Chromium\";v=\"122\", \"Not(A:Brand\";v=\"24\", \"Google Chrome\";v=\"122\"")
                .addHeader("Sec-Ch-Ua-Mobile", "?0")
                .addHeader("Sec-Ch-Ua-Platform", "\"Windows\"")
                .addHeader("Sec-Fetch-Dest", "document")
                .addHeader("Sec-Fetch-Mode", "navigate")
                .addHeader("Sec-Fetch-Site", "none")
                .addHeader("Sec-Fetch-User", "?1")
                .addHeader("Upgrade-Insecure-Requests", "1")
                .addHeader("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36")
                .url(downloadUrl)
                .get()
                .build();
        try (Response response = client.newCall(request).execute()) {
            if (!response.isSuccessful() || response.body() == null) {
                return null;
            }
            BufferedInputStream bodyStream = new BufferedInputStream(response.body().byteStream());
            // Mark before sniffing the content type so the stream can be rewound to its start.
            bodyStream.mark(Integer.MAX_VALUE);
            String mimeType = tika.detect(bodyStream);
            bodyStream.reset();
            if (!GZIP_MIME_TYPE.equals(mimeType)) {
                return uploadFile(bodyStream);
            }
            // The body is gzip-compressed: decompress while uploading and release the inflater.
            try (GZIPInputStream gzipStream = new GZIPInputStream(bodyStream)) {
                return uploadFile(gzipStream);
            }
        }
    }

    /**
     * Uploads the stream as a randomly named {@code .png} file through the active file driver.
     *
     * @return the uploaded file's URL, or {@code null} if the file service returned no URL
     */
    private String uploadFile(InputStream inputStream) throws IOException {
        MultipartFile cMultiFile = new MockMultipartFile("file", String.format("%s.png", RandomUtil.randomNumbers(6)), MediaType.MULTIPART_FORM_DATA_VALUE, inputStream);
        FileService fileService = fileServiceProvider.getFileService(fileServiceProvider.getActiveFileDriver());
        UpLoadFileVo upLoadFileVo = fileService.uploadFile(cMultiFile, null);
        if (upLoadFileVo != null && StringUtils.isNotEmpty(upLoadFileVo.getUrl())) {
            return upLoadFileVo.getUrl();
        }
        return null;
    }

    /**
     * Decodes a Base64 string into an in-memory input stream.
     *
     * @return the decoded stream, or {@code null} if the string is not valid Base64
     */
    private static InputStream base2InputStream(String base64string) {
        try {
            return new ByteArrayInputStream(Base64.getDecoder().decode(base64string));
        } catch (IllegalArgumentException ignored) {
            // Deliberate best effort: an undecodable payload makes the caller fall back to the URL.
            return null;
        }
    }
}
